import random
import numpy as np

from allennlp.predictors import Predictor

from config import Config
from tools.utils import pos_tags
from .candidates import WordNetCandidate, IBPCandidates
from .attacker import BlackBoxAttacker


class GeneticAttack(BlackBoxAttacker):
    """Black-box word-substitution attack driven by a genetic algorithm.

    Each individual is a copy of the input example in which some tokens were
    replaced by synonyms from ``WordNetCandidate``. Individuals are scored by
    how far they push the predictor's gold-label probability down
    (``1 - gold_prob``). The best individual that stays within the
    modification budget is kept unchanged (elitism). The rest of the next
    generation is bred by fitness-proportional parent selection, uniform
    crossover and a single-word perturbation.
    """

    def __init__(self, cf: Config, predictor: Predictor, *,
                 population_size: int = 40, generation_times: int = 20):
        """
        Args:
            cf: attack configuration, forwarded to ``BlackBoxAttacker``.
            predictor: AllenNLP predictor under attack.
            population_size: number of individuals per generation.
            generation_times: maximum number of generations to evaluate.
        """
        super().__init__(cf, predictor)
        self.synonym_candidate = WordNetCandidate(self.supported_postag)

        self.population_size = population_size
        self.generation_times = generation_times
        # Tokenized input of the current ``forward`` call; read by
        # ``select_best_individual`` to measure each individual's edits.
        self.text = None

    def forward(self, text):
        """Run the genetic attack on a single example.

        NOTE: ``text`` is modified in place: ``text['sentence']`` is replaced
        by its token list and ``text['tag']`` is set to the POS tags of those
        tokens.

        Args:
            text: example dict with at least a ``'sentence'`` string. Other keys
                are presumably consumed by ``predict_batch_data`` — confirm.

        Returns:
            The value built by ``self.attack_result`` from the success flag, the
            fraction of modified tokens, and the best adversarial individual
            found (the tokenized input itself if no valid individual appeared).
        """
        text['sentence'] = [t.text for t in self.tokenizer.tokenize(text['sentence'])]
        text['tag'] = pos_tags(list(text['sentence']))
        self.text = text

        best_individual, success = text, False
        generation_g = [self.perturb(text) for _ in range(self.population_size)]
        for g in range(self.generation_times):
            output_g = self.predict_batch_data(generation_g)
            scores = self.fitness(output_g)

            # Highest-fitness individual that still respects the edit budget.
            best_idx, individual = self.select_best_individual(generation_g, scores)
            if best_idx is None:
                break
            best_individual = individual
            # ``stop(..., stop_condition='flip')`` presumably reports whether
            # the prediction flipped — confirm against BlackBoxAttacker.
            success = self.stop([output_g[best_idx]], stop_condition='flip')[0]

            if self.stop([output_g[best_idx]])[0]:
                break
            # A generation bred now would never be evaluated, so skip the work.
            if g == self.generation_times - 1:
                break

            generation_h = [best_individual]  # elitism: keep the best as-is
            for _ in range(self.population_size - 1):
                # Sample two distinct parents by index, proportionally to fitness.
                parent_ids = np.random.choice(len(generation_g), size=2, replace=False, p=scores)
                child = self.crossover([generation_g[k] for k in parent_ids])
                generation_h.append(self.perturb(child))
            generation_g = generation_h

        # Guard against an empty token list to avoid ZeroDivisionError.
        n_tokens = max(len(text['sentence']), 1)
        return self.attack_result(success=success,
                                  length=self.modified_length(text, best_individual) / n_tokens,
                                  adv_example=best_individual)

    def select_best_individual(self, generation_g, scores):
        """Pick the highest-scoring individual within the modification budget.

        Args:
            generation_g: list of individuals (example dicts).
            scores: fitness value for each individual, aligned with ``generation_g``.

        Returns:
            ``(index, individual)`` of the best individual whose number of
            modified tokens is ``<= self.attack_num(len(sentence))``, or
            ``(None, None)`` if no individual qualifies. Ties keep their
            original order (the sort is stable).
        """
        budget = self.attack_num(len(self.text['sentence']))
        ranked = sorted(range(len(scores)), key=lambda k: scores[k], reverse=True)
        for k in ranked:
            if self.modified_length(self.text, generation_g[k]) <= budget:
                return k, generation_g[k]
        return None, None

    def perturb(self, text):
        """Replace one randomly chosen token that has synonyms with a random synonym.

        Positions are tried in random order until one yields a non-empty
        synonym set. If no position does (or the sentence is empty), an
        unchanged copy of ``text`` is returned. This avoids substituting
        ``None`` into the sentence.
        """
        remaining = list(range(len(text['sentence'])))
        while remaining:
            i = random.choice(remaining)
            remaining.remove(i)
            synonyms = self.synonym_candidate.candidate_set(text['sentence'][i], text['tag'][i])
            if len(synonyms) == 0:
                continue
            return self.subsitude(text, i, random.choice(synonyms))
        return self.copy_text(text)

    def fitness(self, outputs):
        """Turn prediction outputs into sampling probabilities.

        The raw score of an individual is ``1 - gold_prob + 1e-6``. The epsilon
        keeps every probability strictly positive, so sampling without
        replacement always has enough candidates. Scores are normalized to
        sum to 1.
        """
        scores = [1 - o['gold_prob'] + 1e-6 for o in outputs]
        sum_score = sum(scores)

        return [s / sum_score for s in scores]

    def crossover(self, parents):
        """Uniform crossover: each token is taken from either parent with equal chance.

        All other keys (including ``'tag'``) come from a copy of the first
        parent. This assumes that substitutions never change the token count,
        so the tags still line up — confirm for ``subsitude``.
        """
        parent1, parent2 = parents[0], parents[1]
        child = self.copy_text(parent1)
        child['sentence'] = [random.choice([w1, w2]) for w1, w2 in zip(parent1['sentence'], parent2['sentence'])]
        return child

    def mutate(self):
        """Unused placeholder; mutation is performed by ``perturb``."""
        pass

    def modified_length(self, text, adv_text):
        """Count token positions where ``adv_text`` differs from ``text``.

        Tokens are compared pairwise with ``zip``, so if the sentences differ in
        length only the shorter length is compared.
        """
        return sum(1 for w, adv_w in zip(text['sentence'], adv_text['sentence']) if w != adv_w)
